import numpy as np
from mytbe import executor
import csv
import argparse

def _fp16_nd_tensor(shape):
    """Return a float16, ND-format tensor descriptor dict for ``shape``."""
    return {"shape": list(shape), "dtype": "float16", "format": "ND"}


def add_dynamic(file, soc_version):
    """Run the "GEMM" operator for each case in a CSV file and log its task time.

    The first row of ``file`` is skipped (presumably a header -- confirm
    against the CSV producer). For every following non-blank row, columns
    1, 2 and 3 are parsed as integers ``m``, ``k`` and ``n``; random float16
    inputs of shape ``[m, k]``, ``[k, n]`` and ``[m, n]`` plus scalar
    ``alpha = 1`` / ``beta = 0`` are passed to ``executor.acl_om``. The last
    column of the row is then overwritten with ``runtime[0]["task_time"]``
    and the row is appended to ``'result_' + file``.

    Args:
        file: Path of the input CSV. The result path is built by prefixing
            the string with ``'result_'``, so this only behaves sensibly for
            a bare file name in the current directory.
        soc_version: Forwarded unchanged to ``executor.acl_om`` as
            ``soc_version``.

    Raises:
        ValueError: If a non-blank data row has fewer than five columns
            (label, m, k, n, and a trailing column that receives the time),
            or if a dimension column is not an integer.
    """
    # The csv module requires newline='' so that it controls line endings
    # itself; otherwise Windows output ends up with "\r\r\n".
    with open(file, newline='') as src:
        rows = list(csv.reader(src))

    # Append mode is kept on purpose: results of repeated runs accumulate.
    with open('result_' + file, 'a', newline='') as dst:
        csv_writer = csv.writer(dst)

        for line_no, row in enumerate(rows[1:], start=2):
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing one).
                continue
            if len(row) < 5:
                # With only four columns the time would overwrite ``n``.
                raise ValueError(
                    "row %d of %s: expected at least 5 columns, got %d"
                    % (line_no, file, len(row))
                )

            m, k, n = int(row[1]), int(row[2]), int(row[3])

            a1 = np.random.random([m, k]).astype("float16")
            a2 = np.random.random([k, n]).astype("float16")
            a3 = np.random.random([m, n]).astype("float16")
            alpha = np.array([1]).astype("float16")
            beta = np.array([0]).astype("float16")

            inputs_info = [
                _fp16_nd_tensor([m, k]),
                _fp16_nd_tensor([k, n]),
                _fp16_nd_tensor([m, n]),
                _fp16_nd_tensor([1]),
                _fp16_nd_tensor([1]),
            ]
            # The output is declared with its static [m, n] shape; separate
            # dict objects are passed for ``outputs_info`` and ``real_out``.
            outputs_info = [_fp16_nd_tensor([m, n])]
            real_out = [_fp16_nd_tensor([m, n])]

            attr_dict = {"transpose_a": False, "transpose_b": False}

            # Set right before the call so it is only touched when a case runs.
            executor.op_debug_level = 1

            result, runtime = executor.acl_om(
                "GEMM",
                inputs_info,
                [a1, a2, a3, alpha, beta],
                outputs_info,
                attr_dict,
                soc_version=soc_version,
                device_id=0,
                real_out=real_out,
            )
            print(runtime[0])

            row[-1] = runtime[0]["task_time"]
            csv_writer.writerow(row)


if __name__ == "__main__":
    # Command-line entry point: read the CSV path and the target chip name
    # from the arguments, then hand both to add_dynamic().
    cli = argparse.ArgumentParser(description="gemm performance test")
    cli.add_argument("--file", type=str, default="gemm_f16.csv", help="gemm input")
    cli.add_argument("--soc_version", type=str, default="Ascend910", help="chip")
    options = cli.parse_args()
    add_dynamic(options.file, options.soc_version)

